import numpy as np
import scipy.signal
import torch
from torch import nn
import torch.nn.functional as F


class FIR(nn.Module):
    """Fixed (non-trainable) FIR filter applied independently to each channel.

    The taps are designed once at construction time with
    ``scipy.signal.firls`` (least-squares design) and stored as a float32
    buffer named ``weight``. Because it is a registered buffer, it follows
    the module through ``.to(...)``, ``.cuda()``, ``.double()`` and similar
    calls; it is non-persistent, so it does not appear in ``state_dict()``.

    Args:
        numtaps: Number of filter taps. ``firls`` requires an odd count, so
            an even value is bumped up by one (a warning is printed).
        bands: Band edges, forwarded unchanged to ``scipy.signal.firls``.
        desired: Desired gain at each band edge, forwarded unchanged to
            ``scipy.signal.firls``.
        fs: Sampling frequency used to interpret ``bands``. The default of 2
            puts the Nyquist frequency at 1.
    """

    def __init__(self, numtaps, bands, desired, fs=2):
        super().__init__()

        if numtaps % 2 == 0:
            print(f"warning: numtaps must be odd, increasing numtaps to {numtaps + 1}")
            numtaps += 1

        taps = scipy.signal.firls(numtaps, bands, desired, fs=fs)

        # Register as a buffer (not a plain attribute) so the taps move with
        # the module across devices/dtypes. persistent=False keeps
        # state_dict() unchanged, so existing checkpoints still load.
        self.register_buffer(
            "weight",
            torch.from_numpy(taps.astype(np.float32)),
            persistent=False,
        )

    def forward(self, x):
        """Filter every channel of ``x`` with the same FIR taps.

        Args:
            x: Input tensor; dimension 1 is treated as the channel axis and
                the last dimension is filtered. Presumably laid out as
                ``(batch, channels, length)`` -- confirm against callers.

        Returns:
            Tensor with the same leading dimensions as ``x``. No padding is
            applied, so the last dimension is shorter by ``numtaps - 1``.
        """
        num_channels = x.size(1)

        # One identical single-input-channel kernel per channel; together
        # with groups=num_channels this filters channels independently.
        kernel = self.weight.view(1, 1, -1).repeat(num_channels, 1, 1)

        return F.conv1d(x, kernel, groups=num_channels)
